{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Attention in Image Captioning.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/zaidalyafeai/AttentioNN/blob/master/Attention_in_Image_Captioning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fO_IJ2_Wr4NG",
        "colab_type": "text"
      },
      "source": [
        "## Introduction\n",
        "Traditional image captioning architectures suffer from a bottleneck problem. Usually, a pretrained model extracts a single fixed feature representation that is fed directly to an RNN to generate the caption. Because this representation describes the image as a whole rather than in parts, it can hurt the caption as decoding progresses: every word is generated from the same global summary. The basic idea behind attention is to have the model assign weights to different parts of the image at each decoding step, so that each word can draw on the most relevant regions."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Et7c-GOEWeK-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install tensorflow-gpu==2.0.0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y6-QTu27lQbG",
        "colab_type": "text"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G4OCoreFk80h",
        "colab_type": "code",
        "outputId": "54c98bd9-f0f3-4163-85a2-54454a28f809",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "from sklearn.model_selection import train_test_split\n",
        "import matplotlib.pyplot as plt\n",
        "from tqdm import tqdm\n",
        "import glob\n",
        "import re\n",
        "import os\n",
        "import json \n",
        "import cv2\n",
        "import time\n",
        "import random\n",
        "from tensorflow.keras.applications.resnet50 import preprocess_input\n",
        "print(tf.__version__)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "2.0.0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HUKoNCqSNTkx",
        "colab_type": "text"
      },
      "source": [
        "## Dataset\n",
        "We use the MS COCO 2014 training set: the images and their caption annotations."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y2jCKjPdiuQm",
        "colab_type": "code",
        "outputId": "c5f2ff5b-8984-401e-a81a-12381141cda7",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 85
        }
      },
      "source": [
        "annotation_zip = tf.keras.utils.get_file('captions.zip',\n",
        "                                          cache_subdir=os.path.abspath('.'),\n",
        "                                          origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n",
        "                                          extract = True)\n",
        "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n",
        "\n",
        "name_of_zip = 'train2014.zip'\n",
        "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n",
        "  image_zip = tf.keras.utils.get_file(name_of_zip,\n",
        "                                      cache_subdir=os.path.abspath('.'),\n",
        "                                      origin = 'http://images.cocodataset.org/zips/train2014.zip',\n",
        "                                      extract = True)\n",
        "  PATH = os.path.dirname(image_zip)+'/train2014/'\n",
        "else:\n",
        "  PATH = os.path.abspath('.')+'/train2014/'"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading data from http://images.cocodataset.org/annotations/annotations_trainval2014.zip\n",
            "252878848/252872794 [==============================] - 4s 0us/step\n",
            "Downloading data from http://images.cocodataset.org/zips/train2014.zip\n",
            "13510574080/13510573713 [==============================] - 246s 0us/step\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZVC-27_MNbi_",
        "colab_type": "text"
      },
      "source": [
        "## Feature Extraction Model\n",
        "\n",
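        "We use ResNet50 pretrained on ImageNet to extract the features of each image. We drop the classification head and cut the network at the `conv5_block3_2_relu` layer, whose output has shape $7 \\times 7 \\times 512$, i.e. 49 image locations with 512 channels each. The choice of the output layer is for performance reasons."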
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xFnRX_CgvXm2",
        "colab_type": "code",
        "outputId": "1551ca1a-4048-4c26-80ce-75d5b52915a2",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 71
        }
      },
      "source": [
        "image_model = tf.keras.applications.ResNet50(include_top=False,\n",
        "                                                weights='imagenet', input_shape = (224, 224, 3))\n",
        "new_input = image_model.input\n",
        "hidden_layer = image_model.get_layer('conv5_block3_2_relu').output\n",
        "\n",
        "feature_extraction_model = tf.keras.Model(new_input, hidden_layer)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading data from https://github.com/keras-team/keras-applications/releases/download/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5\n",
            "94773248/94765736 [==============================] - 2s 0us/step\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Zb_U-S0BPHiI",
        "colab_type": "code",
        "outputId": "e5cd2899-1012-4422-8423-7e3efdffad46",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "list(feature_extraction_model.outputs[0].shape)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[None, 7, 7, 512]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NgGNgJyrNza0",
        "colab_type": "text"
      },
      "source": [
        "## Extract and Save Features "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HqEZYRacmA8o",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "IMG_SIZE = 224\n",
        "\n",
        "# helper function to build the image path and the npy path used to save the features\n",
        "def create_img_npy_paths(imageid):\n",
        "  # COCO file names use the image id left-padded with zeros to 12 digits\n",
        "  padded_id = imageid.zfill(12)\n",
        "  image_path = 'train2014/COCO_train2014_'+padded_id+'.jpg'\n",
        "  npy_path = f'features/{imageid}.npy'\n",
        "  return image_path, npy_path\n",
        "\n",
        "# for each image we extract and save the features\n",
        "def load_featurize_save(imageid, save = True):\n",
        "  image_path, npy_path = create_img_npy_paths(imageid)\n",
        "\n",
        "  #load and preprocess\n",
        "  image = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)\n",
        "  image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])\n",
        "  image = preprocess_input(image)\n",
        "  image = tf.expand_dims(image, 0)\n",
        "\n",
        "  #extract the features \n",
        "  feature = feature_extraction_model(image)\n",
        "  if save:\n",
        "    np.save(npy_path, feature)\n",
        "  return tf.squeeze(feature, 0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Gt7KdRA2lJv3",
        "colab_type": "code",
        "outputId": "6a5c6bf7-080a-49f9-f085-563b581f445c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# create a new folder to save the features\n",
        "if not os.path.isdir('features'):\n",
        "  os.makedirs('features')\n",
        "\n",
        "feature_paths = []\n",
        "caps = []\n",
        "img_paths = []\n",
        "\n",
        "# load the annotations; we only use 10000 (image, caption) pairs sampled at random with replacement\n",
        "with open(annotation_file, 'r') as f:\n",
        "  y = json.load(f)\n",
        "annotations = np.random.choice(y['annotations'], 10000)\n",
        "\n",
        "# loop over the sampled pairs: load each image, extract its features and save them\n",
        "for element in tqdm(annotations):\n",
        "  caption = element['caption']\n",
        "  imageid = str(element['image_id'])\n",
        "  load_featurize_save(imageid)\n",
        "  image_path, npy_path = create_img_npy_paths(imageid)\n",
        "  img_paths.append(image_path)\n",
        "  feature_paths.append(npy_path)\n",
        "  caps.append(caption)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 10000/10000 [15:45<00:00, 10.58it/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e0GpDe37kO8D",
        "colab_type": "text"
      },
      "source": [
        "## Captions Preprocessing\n",
        "\n",
        "In this section we clean the captions (lowercase, strip punctuation) and wrap each one in start (`<s>`) and end (`<e>`) tokens. Our main task is then to map each word to a unique index."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "teE5EQ4uiyJ6",
        "colab_type": "code",
        "outputId": "51c4df71-0756-499e-cbfc-706765f3d71d",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "def preprocess_cap(stmt):\n",
        "  #remove new line character\n",
        "  stmt = stmt.replace(\"\\n\", \"\")\n",
        "  \n",
        "  #only keep alphanumerics\n",
        "  stmt = re.sub(r'([^\\s\\w]|_)+', \"\", stmt.lower().strip())\n",
        "  \n",
        "  #attach start, end special symbols \n",
        "  stmt = '<s> '+stmt+' <e>'\n",
        "  \n",
        "  return stmt\n",
        "\n",
        "processed_caps = [preprocess_cap(cap) for cap in caps]\n",
        "unique_words = set(' '.join(processed_caps).split(' '))\n",
        "num_words = len(unique_words)\n",
        "\n",
        "print('The number of unique words ', num_words)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "The number of unique words  5000\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r-HM2shaloAC",
        "colab_type": "text"
      },
      "source": [
        "Map each word to an integer using `tf.keras.preprocessing`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kHVexp-pkTcv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#helper function to find the length of the longest statement in a corpus\n",
        "def get_max_stmt(stmts):\n",
        "  return max(len(stmt) for stmt in stmts)\n",
        "\n",
        "def get_tensors_dicts(stmts):\n",
        "\n",
        "  # vocabulary size: only the most frequent words are kept\n",
        "  most_frequent = 1000\n",
        "\n",
        "  #tokenize using spaces and convert to integers; rarer words map to the <u> token\n",
        "  tk = tf.keras.preprocessing.text.Tokenizer(split = ' ', filters = \"\", num_words= most_frequent, oov_token = '<u>')\n",
        "  tk.fit_on_texts(stmts)\n",
        "\n",
        "  # create the word-to-index and index-to-word maps, with <p> as the padding token at index 0\n",
        "  tk.word_index = {e:i for e,i in tk.word_index.items() if i < most_frequent} \n",
        "  tk.word_index['<p>'] = 0\n",
        "\n",
        "  word2index = tk.word_index\n",
        "  index2word = {word2index[k]:k for k in word2index.keys()}\n",
        "\n",
        "  # convert the text to sequences\n",
        "  sequences = tk.texts_to_sequences(stmts)\n",
        "\n",
        "  #pad the sequences to have the same length \n",
        "  max_stmt = get_max_stmt(sequences)\n",
        "  output = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen = max_stmt, padding = \"post\")  \n",
        "  \n",
        "  return output, word2index, index2word"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L08abKY_mI5L",
        "colab_type": "text"
      },
      "source": [
        "Get the padded caption tensors and the word/index lookup dictionaries"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0aqjlncsn2c3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "cap_tensors, word2index, index2word = get_tensors_dicts(processed_caps)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ctm4_VlTmOfP",
        "colab_type": "text"
      },
      "source": [
        "## Create Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nz7YzEl1oZKj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def parse_data(path, cap):\n",
        "  features = np.load(path.decode('utf-8'))\n",
        "  return features, cap"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kpVlpeOhpFjo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "BATCH_SIZE = 64\n",
        "\n",
        "#random split\n",
        "paths_train, paths_valid, caps_train, caps_valid = train_test_split(feature_paths, cap_tensors, test_size=0.2)\n",
        "\n",
        "#training dataset\n",
        "train_dataset = tf.data.Dataset.from_tensor_slices((paths_train, caps_train)).shuffle(len(paths_train))\n",
        "train_dataset = train_dataset.map(lambda item1, item2: tf.numpy_function(\n",
        "          parse_data, [item1, item2], [tf.float32, tf.int32]),\n",
        "          num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "train_dataset = train_dataset.batch(BATCH_SIZE, drop_remainder=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "180JcjjMmpQz",
        "colab_type": "text"
      },
      "source": [
        "## Create Models\n",
        "Instantiate some variables"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2eiTtl-imwON",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "units = 1024\n",
        "embedding_dim = 256\n",
        "feature_vector_shape = (49, 512) \n",
        "vocab_size = len(index2word)\n",
        "caps_max_length = cap_tensors.shape[1]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dcADa-3dm6Lv",
        "colab_type": "text"
      },
      "source": [
        "## Attention Mechanism\n",
        "\n",
        "Given the output of the pretrained model as a list of features $a = [a_1, a_2, \\cdots , a_n]$ (one per image location, after the encoder's dense projection) and the input hidden state $h_0$, a small network called the attention network scores every location. The result is a set of attention weights, values between 0 and 1 that sum to 1 over the $n$ locations and tell us which image regions are most important at each step of the decoder. In this notebook we use the following network\n",
        "\n",
        "$$\\text{Attention Network} = \\text{softmax}(V(\\tanh(W_1(a)+ W_2(h_0))))$$\n",
        "\n",
        "where $W_1$, $W_2$ and $V$ are dense layers with $units$, $units$ and $1$ neurons respectively. This results in an output tensor of size $[\\text{batch_sz}, n, 1]$ called the attention_weights. The attention weights are then multiplied element-wise by $a$ and summed over the $n$ locations to generate the context vector\n",
        "\n",
        "$$\\text{Context Vector} = \\sum_{i=1}^{n} \\text{attention_weights}_i \\odot a_i$$\n",
        "\n",
        "Finally, the context vector is concatenated with the embedded input token and fed to the decoder's GRU."
      ]
    },
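    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "A shape walk-through may help connect the formula to the decoder code that follows. This is only a sketch under the values used in this notebook ($n = 49$ image locations, `embedding_dim` $= 256$, `units` $= 1024$). Here $a$ is the encoder output, $h$ is the hidden state passed to the decoder, and the symbols $B$ (batch size), $s$ (scores), $\\alpha$ (attention weights) and $c$ (context vector) are notation introduced here for illustration.\n",
        "\n",
        "$$\n",
        "\\begin{aligned}\n",
        "a &\\in \\mathbb{R}^{B \\times 49 \\times 256}, \\qquad h \\in \\mathbb{R}^{B \\times 1024} \\\\\\\\\n",
        "W_1(a) &\\in \\mathbb{R}^{B \\times 49 \\times 1024}, \\qquad W_2(h) \\in \\mathbb{R}^{B \\times 1 \\times 1024} \\\\\\\\\n",
        "s = V(\\tanh(W_1(a) + W_2(h))) &\\in \\mathbb{R}^{B \\times 49 \\times 1} \\\\\\\\\n",
        "\\alpha = \\text{softmax}(s) &\\in \\mathbb{R}^{B \\times 49 \\times 1}, \\qquad \\sum_{i=1}^{49} \\alpha_i = 1 \\\\\\\\\n",
        "c = \\sum_{i=1}^{49} \\alpha_i \\, a_i &\\in \\mathbb{R}^{B \\times 1 \\times 256}\n",
        "\\end{aligned}\n",
        "$$\n",
        "\n",
        "The sum $W_1(a) + W_2(h)$ relies on broadcasting: the single projected hidden state is added to each of the 49 projected locations. The softmax is taken over the location axis, so the weights for one image always sum to 1."
      ]
    },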
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hBKyplPfnBrH",
        "colab_type": "text"
      },
      "source": [
        "### Encoder"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zsUr6BWIqS_Y",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def gru(units):\n",
        "  return tf.keras.layers.GRU(units, \n",
        "                             return_sequences=True, \n",
        "                             return_state=True, \n",
        "                             recurrent_initializer='glorot_uniform')\n",
        "\n",
        "def get_encoder(feature_vector_shape, embedding_dim, batch_sz):\n",
        "  \n",
        "    features_in = tf.keras.layers.Input(feature_vector_shape)\n",
        "    \n",
        "    # apply a dense layer to every image location, output x: [batch_sz, 49, embedding_dim]\n",
        "    x = tf.keras.layers.Dense(embedding_dim, activation='tanh')(features_in)\n",
        "    \n",
        "    return tf.keras.models.Model(inputs = features_in, outputs = x)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0lebNKA7nFg4",
        "colab_type": "text"
      },
      "source": [
        "### Decoder"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kOBqbMniraMF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_decoder(vocab_size, embedding_dim, units, batch_sz):\n",
        "  \n",
        "  enc_output = tf.keras.layers.Input((feature_vector_shape[0], embedding_dim))\n",
        "  enc_hidden = tf.keras.layers.Input((units,))\n",
        "  dec_input = tf.keras.layers.Input((1,))\n",
        "\n",
        "  W1 = tf.keras.layers.Dense(units)\n",
        "  W2 = tf.keras.layers.Dense(units)\n",
        "  V = tf.keras.layers.Dense(1)\n",
        "      \n",
        "  x = tf.keras.layers.Embedding(vocab_size, embedding_dim)(dec_input)\n",
        "  \n",
        "  #1. attention scores, output [batch_sz, feature_vector_shape[0], 1]\n",
        "  score = V(tf.nn.tanh(W1(enc_output) + W2(tf.expand_dims(enc_hidden, axis = 1))))\n",
        "\n",
        "  #2. attention weights (softmax over locations), output [batch_sz, feature_vector_shape[0], 1]\n",
        "  attention_weights = tf.nn.softmax(score, axis = 1)\n",
        "\n",
        "  #3. context vector (weighted sum over locations), output [batch_sz, 1, embedding_dim]\n",
        "  context_vector = attention_weights * enc_output\n",
        "  context_vector = tf.reduce_sum(context_vector, axis=1, keepdims = True)\n",
        "  \n",
        "  #4. concatenate with the embedded input, output [batch_sz, 1, 2 * embedding_dim]\n",
        "  x = tf.concat([x, context_vector], axis = -1)\n",
        "  \n",
        "  #5. apply GRU, output x:[batch_sz, 1, units] h:[batch_sz, units]\n",
        "  x, h = gru(units)(x)\n",
        "  \n",
        "  #6. drop the length-1 time axis and apply dense, output [batch_sz, vocab_size]\n",
        "  x = tf.reduce_sum(x, axis = 1)\n",
        "  output = tf.keras.layers.Dense(vocab_size)(x)\n",
        " \n",
        "  return tf.keras.models.Model(inputs = [dec_input, enc_hidden, enc_output], outputs = [output, h, attention_weights])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rBRZUsiW2Nhq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "encoder = get_encoder(feature_vector_shape, embedding_dim, BATCH_SIZE)\n",
        "decoder = get_decoder(vocab_size, embedding_dim, units, BATCH_SIZE)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZKy5GkHknL_p",
        "colab_type": "text"
      },
      "source": [
        "## Loss function"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Xl4X9GGr3KWW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "optimizer = tf.optimizers.Adam()\n",
        "\n",
        "def loss_function(real, pred):\n",
        "  # mask out the <p> (index 0) and <u> (index 1) tags because they don't contribute to captioning\n",
        "  mask = tf.logical_and(tf.not_equal(real, 0), tf.not_equal(real, 1))\n",
        "  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred)\n",
        "  loss *= tf.cast(mask, loss.dtype)\n",
        "  return tf.reduce_mean(loss)"
      ],
      "execution_count": 0,
      "outputs": []
    },
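    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Written out, the masked loss above for a single decoding step could be expressed as follows. The symbols $B$ (batch size), $y_b$ (target word index of example $b$), $p_b$ (softmax of the predicted logits for example $b$) and $m_b$ (mask) are notation introduced here for illustration only.\n",
        "\n",
        "$$m_b = \\begin{cases} 0 & \\text{if } y_b \\in \\{0, 1\\} \\\\\\\\ 1 & \\text{otherwise} \\end{cases} \\qquad \\text{loss} = \\frac{1}{B} \\sum_{b=1}^{B} m_b \\left(-\\log p_b(y_b)\\right)$$\n",
        "\n",
        "Note that the mean is taken over all $B$ examples, including the masked ones, because `tf.reduce_mean` is called after the mask is applied."
      ]
    },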
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pRsAuHdyoTmT",
        "colab_type": "text"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3AF1lJol3VAO",
        "colab_type": "code",
        "outputId": "8d15a06c-797b-420f-858c-74c9e0d3e899",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 187
        }
      },
      "source": [
        "import time\n",
        "\n",
        "EPOCHS = 30\n",
        "\n",
        "for epoch in range(EPOCHS):\n",
        "    start = time.time()\n",
        "    \n",
        "    total_loss = 0\n",
        "    \n",
        "    #loop over the training tensors \n",
        "    for (batch, (features, caps)) in enumerate(train_dataset):\n",
        "      \n",
        "        features = tf.reshape(features, (BATCH_SIZE, feature_vector_shape[0], feature_vector_shape[1]))\n",
        "\n",
        "        \n",
        "        loss = 0\n",
        "        \n",
        "        with tf.GradientTape() as tape:\n",
        "\n",
        "            # encode the features\n",
        "            enc_output = encoder(features)\n",
        "            enc_hidden = tf.zeros((BATCH_SIZE, units))\n",
        "\n",
        "            # create the initial input to the decoder \n",
        "            dec_input = tf.expand_dims([word2index['<s>']] * BATCH_SIZE, 1)      \n",
        "            \n",
        "\n",
        "            attention_sum = tf.zeros((BATCH_SIZE, 49, 1))\n",
        "\n",
        "            # Teacher forcing - feeding the target as the next input\n",
        "            for t in range(1, caps.shape[1]):\n",
        "              \n",
        "                # passing enc_output to the decoder\n",
        "                predictions, enc_hidden, attention_weights = decoder([dec_input, enc_hidden, enc_output])\n",
        "                attention_sum += attention_weights\n",
        "\n",
        "                # accumulate the captioning loss for this time step \n",
        "                loss += loss_function(caps[:, t], predictions)   \n",
        "\n",
        "                # evaluate the next input \n",
        "                dec_input = tf.expand_dims(caps[:, t], 1)\n",
        "\n",
        "            # Doubly stochastic regularization https://arxiv.org/abs/1502.03044\n",
        "            # we want the weights of each image location, summed across t, to be close to 1   \n",
        "            loss += 0.05 * tf.reduce_sum((1-attention_sum)**2)\n",
        "\n",
        "        #calculate the loss \n",
        "        batch_loss = (loss / int(caps.shape[1]))\n",
        "        total_loss += batch_loss\n",
        "        \n",
        "        # backprop\n",
        "        variables = encoder.variables + decoder.variables\n",
        "        gradients = tape.gradient(loss, variables)\n",
        "        optimizer.apply_gradients(zip(gradients, variables))\n",
        "        \n",
        "        N_BATCH = batch + 1  # enumerate starts at 0\n",
        "    \n",
        "    # show the average loss per batch \n",
        "    print('Epoch {} Train Loss {:.4f}'.format(epoch + 1,\n",
        "                                        total_loss / N_BATCH))"
      ],
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 1 Train Loss 0.2154\n",
            "Epoch 2 Train Loss 0.1960\n",
            "Epoch 3 Train Loss 0.1778\n",
            "Epoch 4 Train Loss 0.1605\n",
            "Epoch 5 Train Loss 0.1451\n",
            "Epoch 6 Train Loss 0.1298\n",
            "Epoch 7 Train Loss 0.1166\n",
            "Epoch 8 Train Loss 0.1040\n",
            "Epoch 9 Train Loss 0.0953\n",
            "Epoch 10 Train Loss 0.0835\n"
          ],
          "name": "stdout"
        }
      ]
    },
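    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The regularization term added at the end of each batch in the training loop can be written out as below. This is a sketch of what the code computes, not a restatement of the paper. The symbols $\\alpha_{b,t,i}$ (attention weight of example $b$ at decoding step $t$ on image location $i$), $B$ (batch size), $T$ (padded caption length) and $\\lambda$ are notation introduced here; $\\lambda = 0.05$ and $n = 49$ are taken from the code.\n",
        "\n",
        "$$\\text{penalty} = \\lambda \\sum_{b=1}^{B} \\sum_{i=1}^{n} \\left(1 - \\sum_{t=1}^{T-1} \\alpha_{b,t,i}\\right)^2$$\n",
        "\n",
        "It pushes the total attention that each image location receives over the whole caption towards 1, which discourages the model from ignoring some regions entirely."
      ]
    },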
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6YYvP2cEoXNi",
        "colab_type": "text"
      },
      "source": [
        "## Test"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tm_qC-bI7q4Q",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def postprocess_activations(activations):\n",
        "\n",
        "  #resize and convert to image \n",
        "  output = cv2.resize(activations, (224, 224), )\n",
        "  output = output/output.max() # scale so that the largest weight is 1\n",
        "  output = output *255\n",
        "  return 255 - output.astype('uint8')\n",
        "\n",
        "def apply_heatmap(weights, img):\n",
        "  #generate heat maps \n",
        "  heatmap = cv2.applyColorMap(weights, cv2.COLORMAP_JET)\n",
        "  heatmap = cv2.addWeighted(heatmap, 0.6, img, 0.4, 0)\n",
        "  return heatmap\n",
        "\n",
        "def plot_output(output, words):\n",
        "  fig = plt.figure(figsize = (20, 20))\n",
        "  n = len(output)\n",
        "  for i in range(n):\n",
        "    a=fig.add_subplot(1,n,i+1)\n",
        "    plt.subplots_adjust(wspace = 0.0005)\n",
        "    plt.imshow(output[i])\n",
        "    plt.axis('off')\n",
        "    plt.title(words[i])\n",
        "    \n",
        "def caption(imageid):\n",
        " \n",
        "    # load the image\n",
        "    image_path, npy_path = create_img_npy_paths(imageid)\n",
        "    numpy_image = cv2.imread(image_path)[:,:,::-1]\n",
        "    numpy_image = cv2.resize(numpy_image, (224, 224))\n",
        "    \n",
        "    # extract the features  \n",
        "    features = load_featurize_save(imageid, save = False)\n",
        "    features = tf.reshape(features, (1, feature_vector_shape[0], feature_vector_shape[1]))\n",
        "    \n",
        "    #feed encoder \n",
        "    enc_out = encoder(features)\n",
        "\n",
        "    enc_hidden = tf.zeros((1, units))\n",
        "    # prepare first input to the decoder \n",
        "    dec_input = tf.expand_dims([word2index['<s>']], 0)\n",
        "    \n",
        "    result = \"\"\n",
        "    words = ['']\n",
        "    output = [numpy_image.copy()]\n",
        "\n",
        "    for t in range(10):\n",
        "        \n",
        "        # feed decoder \n",
        "        predictions, enc_hidden, attention_weights = decoder([dec_input, enc_hidden, enc_out])\n",
        "\n",
        "        # extract the attention weights and post process them \n",
        "        attention_image = attention_weights.numpy().reshape((7,7))\n",
        "        heatmap = postprocess_activations(attention_image)\n",
        "\n",
        "        output.append(apply_heatmap(heatmap, numpy_image))\n",
        "\n",
        "        # predict next word \n",
        "        predicted_id = tf.argmax(predictions[0]).numpy()\n",
        "        \n",
        "        next_word = index2word[predicted_id]\n",
        "        result += next_word + ' '\n",
        "        words.append(next_word)\n",
        "        \n",
        "        \n",
        "        # exit on end token \n",
        "        if index2word[predicted_id] == '<e>':\n",
        "            plot_output(output, words)\n",
        "            return result\n",
        "        \n",
        "        # the predicted ID is fed back into the model\n",
        "        dec_input = tf.expand_dims([predicted_id], 0)\n",
        "    \n",
        "    \n",
        "    plot_output(output, words)\n",
        "    return result"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ZqSjko68SDKx",
        "colab": {}
      },
      "source": [
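        "# caption the first 10 sampled annotations (these images may also have been seen during training)\n",
        "for annotation in annotations[0:10]:\n",
        "  caption(str(annotation['image_id']))"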
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rTTZ5IzEx9jG",
        "colab_type": "text"
      },
      "source": [
        "## References\n",
        "\n",
        "1. https://www.tensorflow.org/tutorials/text/image_captioning\n",
        "2. https://medium.com/syncedreview/a-brief-overview-of-attention-mechanism-13c578ba9129"
      ]
    }
  ]
}